/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_optimizer/graph_fusion/fusion_pass_manager/builtin_pass/node_optimize/checker/concat_optimize_checker.h"
#include <external/graph/types.h>
#include <algorithm>
#include "common/configuration.h"

namespace fe {
// Runs the precondition list for `node_ptr` (variant without quant handling).
// The conditions are evaluated in a fixed order; the first one that fails ends
// the evaluation, so later checks (and any debug logging they do) are skipped.
// Returns true only when every condition holds.
bool ConcatOptimizeChecker::Check(const ge::NodePtr &node_ptr) {
  // All inputs coming from one single producer is a rejection criterion.
  if (IsInputFromSameNode(node_ptr)) {
    return false;
  }
  if (!IsDimC(node_ptr, CONCAT_DIM, true)) {
    return false;
  }
  if (!IsDimCAligned(node_ptr)) {
    return false;
  }
  if (!is_pre_node_valid(node_ptr)) {
    return false;
  }
  // Successors are inspected with a recursion depth of 1 and without relu pass-through.
  if (!is_next_node_valid(node_ptr, 1, false)) {
    return false;
  }
  if (IsDCorMDC()) {
    return false;
  }
  if (!IsNCHWOrNHWC(node_ptr)) {
    return false;
  }
  return IsInputNotData(node_ptr);
}

// Runs the precondition list for `node_ptr` (quant-aware variant).
// Differs from Check() in three ways visible here: the C-dim alignment test is
// IsDimCAlignedWithQuant(), successors are inspected with has_relu = true, and
// IsInputNotData() is not part of the list.
// Evaluation stops at the first failing condition.
bool ConcatOptimizeChecker::CheckWithQuant(const ge::NodePtr &node_ptr) {
  // All inputs coming from one single producer is a rejection criterion.
  if (IsInputFromSameNode(node_ptr)) {
    return false;
  }
  if (!IsDimC(node_ptr, CONCAT_DIM, true)) {
    return false;
  }
  if (!IsDimCAlignedWithQuant(node_ptr)) {
    return false;
  }
  if (!is_pre_node_valid(node_ptr)) {
    return false;
  }
  if (!is_next_node_valid(node_ptr, 1, true)) {
    return false;
  }
  if (IsDCorMDC()) {
    return false;
  }
  return IsNCHWOrNHWC(node_ptr);
}

bool ConcatOptimizeChecker::IsDCorMDC() {
  bool is_d_cor_md_c = fe::Configuration::Instance(fe::AI_CORE_NAME).IsDCorMDCSoc();

  string precision_mode_str = fe::Configuration::Instance(AI_CORE_NAME).GetPrecisionModeStr();
  return is_d_cor_md_c && precision_mode_str != FORCE_FP16;
}

// Verifies every input tensor descriptor of `node_ptr`:
//   - its origin format must be NCHW or NHWC, and
//   - its shape must have CONCAT_SHAPE_DIM_DEFAULT dimensions
//     (4 according to the debug message below).
// Returns false at the first input violating either rule; true otherwise
// (including when the node has no inputs).
bool ConcatOptimizeChecker::IsNCHWOrNHWC(const ge::NodePtr &node_ptr) {
  for (ge::GeTensorDesc &desc : node_ptr->GetOpDesc()->GetAllInputsDesc()) {
    const auto origin_format = desc.GetOriginFormat();
    const bool format_supported = (origin_format == ge::FORMAT_NCHW) || (origin_format == ge::FORMAT_NHWC);
    if (!format_supported) {
      return false;
    }

    const bool rank_supported = (desc.GetShape().GetDimNum() == CONCAT_SHAPE_DIM_DEFAULT);
    if (!rank_supported) {
      FE_LOGD("The input dimension of the concat operator must be 4.");
      return false;
    }
  }
  return true;
}

// Tells whether every input of `node_ptr` is produced by the same node.
// The producer of input 0 is taken as the reference and the op desc pointer of
// each further producer is compared against it.
// Returns false when the node has fewer than two inputs, when any producer
// cannot be resolved through NodeOptimizeUtils::GetPreNode, or when a producer
// differs from the reference; returns true (and logs) otherwise.
bool ConcatOptimizeChecker::IsInputFromSameNode(const ge::NodePtr &node_ptr) {
  const std::string name = node_ptr->GetName();
  const size_t input_count = node_ptr->GetOpDesc()->GetInputsSize();
  if (input_count < 2) {
    return false;
  }

  // Reference producer: the node feeding input 0.
  ge::NodePtr first_producer;
  if (NodeOptimizeUtils::GetPreNode(node_ptr, 0, first_producer) != SUCCESS) {
    FE_LOGD("Node[%s]: get the previous node of the input0 not success.", name.c_str());
    return false;
  }
  const ge::OpDescPtr reference_desc = first_producer->GetOpDesc();

  // Any remaining input with a different producer means "not the same node".
  size_t idx = 1;
  while (idx != input_count) {
    ge::NodePtr producer;
    if (NodeOptimizeUtils::GetPreNode(node_ptr, idx, producer) != SUCCESS) {
      FE_LOGD("Node[%s]: get the previous node of the input [%zu] not success.", name.c_str(), idx);
      return false;
    }
    if (producer->GetOpDesc() != reference_desc) {
      return false;
    }
    ++idx;
  }

  FE_LOGD("Node[%s]: all inputs are from the same node, check failed.", name.c_str());
  return true;
}

// Checks the C dimension of every input of `node_ptr` except the last one.
// For each inspected input the C dimension is fetched with GetDimC() and then
// validated with IsDimCOfInputAligned(..., false).
// Returns false (with a debug log) at the first input whose C dimension cannot
// be obtained or is not aligned; true otherwise.
bool ConcatOptimizeChecker::IsDimCAligned(const ge::NodePtr &node_ptr) {
  const std::string name = node_ptr->GetName();
  const ge::OpDescPtr desc = node_ptr->GetOpDesc();
  const size_t input_count = desc->GetInputsSize();

  // `idx + 1 < input_count` stops one short of the end: the last input is never inspected.
  for (size_t idx = 0; idx + 1 < input_count; ++idx) {
    ge::GeTensorDesc input_desc = desc->GetInputDesc(idx);

    int channel_dim = 0;
    if (GetDimC(input_desc, channel_dim) != SUCCESS) {
      FE_LOGD("Node[%s]: get the dim C of the input [%zu] not success, check failed.",
          name.c_str(), idx);
      return false;
    }

    const bool aligned = IsDimCOfInputAligned(input_desc, channel_dim, false);
    if (!aligned) {
      FE_LOGD("Node[%s]: the dim C of the input [%zu] is not aliged, check failed.", name.c_str(), idx);
      return false;
    }
  }
  return true;
}

// Quant-aware variant of the C-dimension alignment check.
// Every input of `node_ptr` except the last one has its C dimension fetched
// with GetDimC() and validated with IsDimCOfInputAligned(); the third argument
// of that call tells whether the first consumer attached to output anchor 0 is
// a QUANT node.
// Returns false (with a debug log) at the first input whose C dimension cannot
// be obtained or is not aligned; true otherwise.
bool ConcatOptimizeChecker::IsDimCAlignedWithQuant(const ge::NodePtr &node_ptr) {
  string node_name = node_ptr->GetName();
  ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
  size_t input_size = op_desc_ptr->GetInputsSize();

  // Whether a QUANT node follows depends only on the node's output side, not on
  // the input being inspected, so it is determined once before the loop.
  bool has_quant = false;
  const auto out_data_anchor = node_ptr->GetOutDataAnchor(0);
  if (out_data_anchor != nullptr) {
    auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
    if (!peer_in_data_anchors.empty()) {
      // Only the first consumer is examined.
      ge::InDataAnchorPtr in_data_anchor_ptr = peer_in_data_anchors.at(0);
      if (in_data_anchor_ptr != nullptr) {
        const ge::NodePtr owner_node = in_data_anchor_ptr->GetOwnerNode();
        has_quant = (owner_node != nullptr) && (owner_node->GetType() == QUANT);
      }
    }
  }

  for (size_t i = 0; i != input_size; ++i) {
    // 1. do nothing for the last one
    if (i == input_size - 1) {
      continue;
    }

    // 2. get the dim_c
    ge::GeTensorDesc tensor_desc = op_desc_ptr->GetInputDesc(i);
    int dim_c = 0;
    Status status = GetDimC(tensor_desc, dim_c);
    if (status != SUCCESS) {
      FE_LOGD("Node[%s]: get the dim C of the input [%zu] not success, check failed.",
          node_name.c_str(), i);
      return false;
    }

    // 3. check the dim_c
    if (!IsDimCOfInputAligned(tensor_desc, dim_c, has_quant)) {
      FE_LOGD("Node[%s]: the dim C of the input [%zu] is not aliged, check failed.", node_name.c_str(), i);
      return false;
    }
  }
  return true;
}

// Validates the producer of every input of `node_ptr`.
// A producer is rejected when any of these boolean attributes is set on its op
// desc: CONTINUOUS_INPUT, CONTINUOUS_OUTPUT, REFERENCE, NOTASK,
// OUTPUT_REUSE_INPUT, NOPADDING_CONTINUOUS_INPUT, NOPADDING_CONTINUOUS_OUTPUT.
// An attribute that cannot be read counts as false.
// Returns false (with a debug log) when a producer cannot be resolved or is
// rejected; true otherwise.
bool ConcatOptimizeChecker::is_pre_node_valid(const ge::NodePtr &node_ptr) {
  const std::string name = node_ptr->GetName();
  const size_t input_count = node_ptr->GetOpDesc()->GetInputsSize();

  for (size_t idx = 0; idx < input_count; ++idx) {
    ge::NodePtr producer;
    if (NodeOptimizeUtils::GetPreNode(node_ptr, idx, producer) != SUCCESS) {
      FE_LOGD("Node[%s]: get the previous node of the input [%zu] not success, check failed.",
              name.c_str(), idx);
      return false;
    }

    ge::OpDescPtr producer_desc = producer->GetOpDesc();
    // Reads one boolean attribute of the producer; the GetBool result is ignored
    // on purpose, so a missing attribute leaves the value at false.
    const auto read_flag = [&producer_desc](const std::string &attr_name) {
      bool value = false;
      (void)ge::AttrUtils::GetBool(producer_desc, attr_name, value);
      return value;
    };

    // All seven attributes are read unconditionally, in this order.
    const bool continuous_in = read_flag(ge::ATTR_NAME_CONTINUOUS_INPUT);
    const bool continuous_out = read_flag(ge::ATTR_NAME_CONTINUOUS_OUTPUT);
    const bool reference = read_flag(ge::ATTR_NAME_REFERENCE);
    const bool task_less = read_flag(ge::ATTR_NAME_NOTASK);
    const bool reuses_input = read_flag(ge::ATTR_NAME_OUTPUT_REUSE_INPUT);
    const bool nopad_continuous_in = read_flag(ge::ATTR_NAME_NOPADDING_CONTINUOUS_INPUT);
    const bool nopad_continuous_out = read_flag(ge::ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT);

    const bool unsupported = continuous_in || continuous_out || reference || task_less || reuses_input ||
                             nopad_continuous_in || nopad_continuous_out;
    if (unsupported) {
      FE_LOGD("Node[%s]: the previous node [%s] is not supported, check failed.", name.c_str(),
              producer->GetName().c_str());
      return false;
    }
  }
  return true;
}

// Validates the consumers of `concat_node`.
//
// For every consumer attached to any output data anchor of `concat_node`:
//   - a consumer without owner node, op desc or input tensor desc fails the check;
//   - a NETOUTPUT consumer whose connected input uses FORMAT_NC1HWC0 (either as
//     primary format or through the ATTR_NAME_STORAGE_FORMAT attribute) fails
//     the check;
//   - while `depth` > 0, a QUANT consumer (or a LEAKYRELU / RELU consumer when
//     `has_relu` is true) is followed recursively with `depth - 1`.
//
// @param concat_node node whose consumers are inspected
// @param depth       remaining number of pass-through levels to follow
// @param has_relu    whether LEAKYRELU / RELU consumers are followed as well
// @return true when no inspected consumer fails the check
bool ConcatOptimizeChecker::is_next_node_valid(ge::NodePtr concat_node, uint32_t depth, bool has_relu) {
  for (auto &output_anchor : concat_node->GetAllOutDataAnchors()) {
    auto peer_in_anchors = output_anchor->GetPeerInDataAnchors();
    for (size_t i = 0; i < peer_in_anchors.size(); i++) {
      ge::NodePtr next_node = peer_in_anchors.at(i)->GetOwnerNode();
      if (next_node == nullptr) {
        return false;
      }
      // The op desc itself must be checked here before it is dereferenced below.
      ge::OpDescPtr next_node_desc = next_node->GetOpDesc();
      if (next_node_desc == nullptr) {
        return false;
      }
      uint32_t in_data_anchor_index = peer_in_anchors.at(i)->GetIdx();

      string next_node_name = next_node_desc->GetName();
      // The tensor desc pointer is dereferenced by GetInt() and GetFormat() below,
      // so it has to be validated first.
      ge::GeTensorDescPtr geTensorDescPtr = next_node_desc->MutableInputDesc(in_data_anchor_index);
      if (geTensorDescPtr == nullptr) {
        return false;
      }

      // The storage format attribute is optional; FORMAT_RESERVED is kept when it is absent.
      int64_t format = ge::FORMAT_RESERVED;
      (void)ge::AttrUtils::GetInt(*geTensorDescPtr, ge::ATTR_NAME_STORAGE_FORMAT, format);
      ge::Format storage_format = static_cast<ge::Format>(format);
      bool no_need_optimize = next_node_desc->GetType() == NETOUTPUT &&
                              (ge::GetPrimaryFormat(geTensorDescPtr->GetFormat()) == ge::FORMAT_NC1HWC0 ||
                               storage_format == ge::FORMAT_NC1HWC0);
      if (no_need_optimize) {
        FE_LOGD("Next node %s is netoutput, %s can not optimize.", next_node_name.c_str(),
                concat_node->GetName().c_str());
        return false;
      }
      // NOTE(review): this returns the verdict of the first pass-through consumer
      // directly, so any remaining consumers / output anchors are not inspected.
      // Kept as-is to preserve existing behavior - confirm whether that is intended.
      if (depth > 0 && (next_node_desc->GetType() == QUANT ||
                        (has_relu && next_node_desc->GetType() == LEAKYRELU) ||
                        (has_relu && next_node_desc->GetType() == RELU))) {
        return is_next_node_valid(next_node, depth - 1, has_relu);
      }
    }
  }
  return true;
}
}  // namespace fe
